-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ptnn4both datatypes and alignment tests #1827
Conversation
- Now you can just replace the Pytorch model structure to run a NN model. | ||
|
||
We provide an example to demonstrate the effectiveness of the current design. | ||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158 dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tests/model/test_general_nn.py
Outdated
for ds, model in reversed(list(zip((tsds, tbds), model_l))): | ||
model.fit(ds) # It works | ||
model.predict(ds) # It works | ||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The break should be removed
tests/model/test_general_nn.py
Outdated
), | ||
] | ||
|
||
for ds, model in reversed(list(zip((tsds, tbds), model_l))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should remove the reverse.
from torch.nn import DataParallel | ||
|
||
|
||
class GeneralPTNN(Model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class should be removed
tests/model/test_general_nn.py
Outdated
|
||
|
||
if __name__ == "__main__": | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may not work...
tests/model/test_general_nn.py
Outdated
|
||
import unittest | ||
|
||
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
include in the test.
train: [2008-01-01, 2014-12-31] | ||
valid: [2015-01-01, 2016-12-31] | ||
test: [2017-01-01, 2020-08-01] | ||
step_len: 20 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this parameter
Support Pytorch_nn for processing time-series data
Description
We create a new class GeneralPTNN and a test file for it.
Motivation and Context
How Has This Been Tested?
pytest qlib/tests/test_all_pipeline.py
under upper directory ofqlib
.Screenshots of Test Results (if appropriate):
Pipeline test:
Your own tests:
Types of changes